﻿#include "Mnist.h"
//#include "opencv.hpp"
#include <iostream>

//using namespace cv;

// Builds one hidden block: a fully connected layer followed by an activation.
// NOTE: despite the "Sig" in the class name, the activation registered here is ReLU, not a sigmoid.
//   input  - number of input features
//   output - number of output features
LinearSigImpl::LinearSigImpl(int input, int output)
    : ln(nullptr),
      sn(nullptr)
{
    // Register the affine layer first, then the activation, matching the order used in forward().
    ln = register_module("ln", torch::nn::Linear(torch::nn::LinearOptions(input, output)));
    sn = register_module("sn", torch::nn::ReLU(torch::nn::ReLUOptions()));
}

// Applies the linear layer and then the activation (ReLU) to x.
torch::Tensor LinearSigImpl::forward(torch::Tensor x)
{
    return sn->forward(ln->forward(x));
}

// Three-layer perceptron: input -> 256 -> 128 -> outputCount.
// The two hidden blocks are LinearSig (Linear + ReLU); the last layer is a plain Linear
// producing unnormalized scores.
Mlp::Mlp(int input, int outputCount)
    : ln1(nullptr), ln2(nullptr), output(nullptr)
{
    constexpr int firstHidden = 256;
    constexpr int secondHidden = 128;

    ln1 = register_module("ln1", LinearSig(input, firstHidden));
    ln2 = register_module("ln2", LinearSig(firstHidden, secondHidden));
    // The member is called `output`, but it is registered under the name "ln3".
    output = register_module("ln3", torch::nn::Linear(secondHidden, outputCount));
}

// Runs x through both hidden blocks and the output layer.
// Returns raw scores (no softmax here; predict() applies softmax on top of this).
torch::Tensor Mlp::forward(torch::Tensor x)
{
    torch::Tensor hidden = ln2->forward(ln1->forward(x));
    return output->forward(hidden);
}

//---------------------------------------------------------------------------------------
// Inference helper: post-processes the network output so callers need less code.
// x: input data, one sample per row.
// Returns one (class index, confidence) pair per row, where confidence is the softmax
// probability of the arg-max class.
//---------------------------------------------------------------------------------------
vector<pair<int, float>> Mlp::predict(torch::Tensor x)
{
    // Pure inference: do not record an autograd graph for the whole batch.
    torch::NoGradGuard noGrad;

    torch::Tensor probs = torch::softmax(forward(x), 1);
    tuple<torch::Tensor, torch::Tensor> best = probs.max(1);

    // Bring results to CPU once and read them through accessors instead of one
    // `.item()` dispatch per element.
    torch::Tensor confidence = std::get<0>(best).to(torch::kCPU, torch::kFloat).contiguous();
    torch::Tensor classIndex = std::get<1>(best).to(torch::kCPU, torch::kLong).contiguous();
    auto confAcc = confidence.accessor<float, 1>();
    auto classAcc = classIndex.accessor<int64_t, 1>();

    const int64_t rows = confidence.size(0);
    vector<pair<int, float>> results;
    results.reserve(static_cast<size_t>(rows));
    for (int64_t i = 0; i < rows; ++i)
    {
        results.emplace_back(static_cast<int>(classAcc[i]), confAcc[i]);
    }
    return results;
}

/////////////////////////////////////////////////////////////////////////////////////////

// Loads an image file and its matching label file.
//   sampleFile - IDX image file; rows become `inputs`
//   labelFile  - IDX label file; rows become one-hot `labels`
// Throws (via TORCH_CHECK) if the two files do not describe the same number of samples,
// because batch() pairs row i of `inputs` with row i of `labels`.
DataLoader::DataLoader(const string& sampleFile, const string& labelFile)
{
    loadImages(sampleFile);
    loadLabels(labelFile);
    TORCH_CHECK(inputs.size(0) == labels.size(0),
                "DataLoader: '", sampleFile, "' has ", inputs.size(0), " images but '",
                labelFile, "' has ", labels.size(0), " labels");
}

std::tuple<torch::Tensor, torch::Tensor> DataLoader::batch(int size)
{
    torch::Tensor a = torch::zeros({ size, inputs.size(1) });
    torch::Tensor b = torch::zeros({ size, labels.size(1) });
    int whole = inputs.size(0);
    uniform_int_distribution<int> dist(0, whole - 1);
    int begin = dist(mt);
    for (int i = begin; i < begin + size; i++)
    {
        a[i - begin] = inputs[i % whole];
        b[i - begin] = labels[i % whole];
    }
    return { a, b };
}

// Returns the full dataset as (images, one-hot labels).
// Tensor copies are shallow, so the returned tensors share storage with the loader.
tuple<torch::Tensor, torch::Tensor> DataLoader::all() const
{
    return std::make_tuple(inputs, labels);
}

void DataLoader::loadImages(const string& file)
{
    fstream fs(file, ios::in | ios::binary);
    fs.seekg(0, ios::end);
    auto size = fs.tellg();
    fs.seekg(0, ios::beg);
    int magic;
    fs.read((char*)&magic, 4);
    magic = reverse(magic);
    int num;
    fs.read((char*)&num, 4);
    num = reverse(num);
    int w;
    fs.read((char*)&w, 4);
    w = reverse(w);
    int h;
    fs.read((char*)&h, 4);
    h = reverse(h);

    int length = w * h;
    vector<float> content;
    content.reserve(num * length);
    while (!fs.eof())
    {
        vector<char> byte(length);
        fs.read(byte.data(), byte.size());
        if (!fs.fail())
        {
            for (int j = 0; j < length; j++)
            {
                content.push_back((unsigned char)byte[j] / 255.0f);
            }
        }
    }
    fs.close();

    /* from_blob函数默认不拷贝内存，所以需要克隆一下 */
    inputs = torch::from_blob(content.data(), { num, length }, c10::ScalarType::Float).clone();

    //Mat image(h, w, CV_32FC1, inputs[0].data_ptr());
    //imshow("dsadad", image);
    //waitKeyEx();
}

void DataLoader::loadLabels(const string& file)
{
    fstream fs(file, ios::in | ios::binary);
    fs.seekg(0, ios::end);
    auto size = fs.tellg();
    fs.seekg(0, ios::beg);
    int magic;
    fs.read((char*)&magic, 4);
    magic = reverse(magic);
    int num;
    fs.read((char*)&num, 4);
    num = reverse(num);
    vector<char> content;
    content.reserve((uint64_t)size - 8);
    while (!fs.eof())
    {
        char byte;
        fs.read(&byte, 1);
        if (!fs.fail())
        {
            content.push_back(byte);
            //cout << int(byte) << " ";
        }
    }
    fs.close();

    labels = torch::zeros({ num, 10 });
    for (int i = 0; i < num; i++)
    {
        labels[i][content[i]] = 1.0f;
    }
}

// Reverses the byte order of a 32-bit integer (0xAABBCCDD -> 0xDDCCBBAA).
// The loaders use it to decode the big-endian IDX header fields. The swap is unconditional,
// so the decoding is only correct on a little-endian host.
int DataLoader::reverse(int x)
{
    const uint32_t bits = static_cast<uint32_t>(x);
    const uint32_t swapped = (bits >> 24)
                           | ((bits >> 8) & 0x0000FF00u)
                           | ((bits << 8) & 0x00FF0000u)
                           | (bits << 24);
    return static_cast<int>(swapped);
}

// Converts a 2-D label tensor to per-row class indices: the column of the largest entry
// in each row (for one-hot rows, the class number).
vector<int> labelToIndex(torch::Tensor labels)
{
    tuple<torch::Tensor, torch::Tensor> value = labels.max(1);
    // Move indices to CPU once and read them through an accessor instead of one
    // `.item()` dispatch per row.
    torch::Tensor index = std::get<1>(value).to(torch::kCPU, torch::kLong).contiguous();
    auto indexAcc = index.accessor<int64_t, 1>();

    const int64_t rows = index.size(0);
    vector<int> myData;
    myData.reserve(static_cast<size_t>(rows));
    for (int64_t i = 0; i < rows; ++i)
    {
        myData.push_back(static_cast<int>(indexAcc[i]));
    }
    return myData;
}

int main()
{
    DataLoader sample("../Dataset/train-images.idx3-ubyte", "../Dataset/train-labels.idx1-ubyte");

    Mlp machine(784, 10);

    /* 训练过程 */
    torch::optim::SGD optim(machine.parameters(), torch::optim::SGDOptions(0.2));
    torch::nn::CrossEntropyLoss lossFunc;
    machine.train();
    for (int i = 0; i < 400000; i++)
    {
        std::tuple<torch::Tensor, torch::Tensor> once = sample.batch(40);
        torch::Tensor predict = machine.forward(std::get<0>(once));
        torch::Tensor loss = lossFunc(predict, std::get<1>(once));
        optim.zero_grad();
        loss.backward();
        optim.step();
        if (i % 5000 == 0)
        {
            /* 每5000次循环输出一次损失函数值 */
            cout << "LOOP:" << i << ",LOSS=" << loss.item() << endl;
        }
    }

    machine.eval();

    /* 验证数据 */
    DataLoader check("../Dataset/t10k-images.idx3-ubyte", "../Dataset/t10k-labels.idx1-ubyte");
    tuple<torch::Tensor, torch::Tensor> test = check.all();
    vector<pair<int, float>> result = machine.predict(std::get<0>(test));
    vector<int> trueClass = labelToIndex(std::get<1>(test));
    int count = (int)result.size();
    for (int i = 0; i < count; i++)
    {
        cout << "真实=" << trueClass[i] << ",预测的=" << result[i].first << "(" << result[i].second << ")" << endl;
    }

    /* 计算正确率 */
    int trueCount = 0;
    for (int i = 0; i < count; i++)
    {
        trueCount += (trueClass[i] == result[i].first);
    }
    cout << "正确率=" << 1.0f * trueCount / count << endl;

    //Mat image(28, 28, CV_32FC1, std::get<0>(test)[10].data_ptr());
    //imshow("dasdasd", image);
    //resizeWindow("dasdasd", Size(400, 400));
    //waitKey();

    int z = 0;
    return 0;
}





